{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found local copy...\n",
      "generating training, validation splits...\n",
      "100%|██████████| 728/728 [00:00<00:00, 1378.45it/s]\n",
      "Using backend: pytorch\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Drug Property Prediction Mode...\n",
      "in total: 637 drugs\n",
      "encoding drug...\n",
      "unique drugs: 634\n",
      "do not do train/test split on the data for already splitted data\n",
      "Drug Property Prediction Mode...\n",
      "in total: 91 drugs\n",
      "encoding drug...\n",
      "unique drugs: 91\n",
      "do not do train/test split on the data for already splitted data\n",
      "Drug Property Prediction Mode...\n",
      "in total: 182 drugs\n",
      "encoding drug...\n",
      "unique drugs: 181\n",
      "do not do train/test split on the data for already splitted data\n",
      "Let's use CPU/s!\n",
      "--- Data Preparation ---\n",
      "--- Go for Training ---\n",
      "Training at Epoch 1 iteration 0 with loss 33.4654. Total time 0.0 hours\n",
      "Validation at Epoch 1 , MSE: 1.84570 , Pearson Correlation: 0.28794 with p-value: 5.65E-03 , Concordance Index: 0.60782\n",
      "Training at Epoch 2 iteration 0 with loss 4.77358. Total time 0.00055 hours\n",
      "Validation at Epoch 2 , MSE: 3.82274 , Pearson Correlation: 0.47650 with p-value: 1.80E-06 , Concordance Index: 0.65207\n",
      "Training at Epoch 3 iteration 0 with loss 2.30286. Total time 0.00138 hours\n",
      "Validation at Epoch 3 , MSE: 2.37858 , Pearson Correlation: 0.44194 with p-value: 1.16E-05 , Concordance Index: 0.63789\n",
      "Training at Epoch 4 iteration 0 with loss 1.85074. Total time 0.00194 hours\n",
      "Validation at Epoch 4 , MSE: 1.72014 , Pearson Correlation: 0.57234 with p-value: 3.08E-09 , Concordance Index: 0.68007\n",
      "Training at Epoch 5 iteration 0 with loss 0.97545. Total time 0.00277 hours\n",
      "Validation at Epoch 5 , MSE: 1.26116 , Pearson Correlation: 0.58364 with p-value: 1.26E-09 , Concordance Index: 0.69278\n",
      "Training at Epoch 6 iteration 0 with loss 0.74741. Total time 0.00333 hours\n",
      "Validation at Epoch 6 , MSE: 1.13191 , Pearson Correlation: 0.67505 with p-value: 2.17E-13 , Concordance Index: 0.73080\n",
      "Training at Epoch 7 iteration 0 with loss 0.67488. Total time 0.00416 hours\n",
      "Validation at Epoch 7 , MSE: 0.91687 , Pearson Correlation: 0.68444 with p-value: 7.44E-14 , Concordance Index: 0.72249\n",
      "Training at Epoch 8 iteration 0 with loss 0.54359. Total time 0.00472 hours\n",
      "Validation at Epoch 8 , MSE: 0.80481 , Pearson Correlation: 0.65949 with p-value: 1.17E-12 , Concordance Index: 0.69938\n",
      "Training at Epoch 9 iteration 0 with loss 0.35596. Total time 0.00527 hours\n",
      "Validation at Epoch 9 , MSE: 0.96110 , Pearson Correlation: 0.68223 with p-value: 9.60E-14 , Concordance Index: 0.71784\n",
      "Training at Epoch 10 iteration 0 with loss 0.55490. Total time 0.00611 hours\n",
      "Validation at Epoch 10 , MSE: 0.78065 , Pearson Correlation: 0.68377 with p-value: 8.04E-14 , Concordance Index: 0.71931\n",
      "--- Go for Testing ---\n",
      "Testing MSE: 1.3013522857612316 , Pearson Correlation: 0.4783977334968594 with p-value: 8.47E-12 , Concordance Index: 0.5889595411757371\n",
      "--- Training Finished ---\n",
      "predicting...\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYMAAAEPCAYAAACgFqixAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAfNUlEQVR4nO3de5RcZZnv8e8vF0iHhNxoIGlIMl5HZZwALQfUg1wcRccRFR1HMwgDY8TLERzOHEHGY1BAXUtBVJxDVCRI6+hBHBlEFFBw8AjagYBgZAKaRCAkARITCBeTPOePd5epdKo6VZ2qvatr/z5r1erab+2q/Wzo1NPvXRGBmZmV25iiAzAzs+I5GZiZmZOBmZk5GZiZGU4GZmaGk4GZmZFzMpA0QdIvJN0l6V5J52bll0v6naSl2WNennGZmZXduJyv9wxwTEQ8IWk8cKukH2Sv/XNEXJVzPGZmRs7JINIMtyeyw/HZw7PezMwKprxnIEsaCywBngdcEhEflnQ5cASp5nATcFZEPDPc5+yzzz4xd+7cNkdrZtZdlixZ8mhE9A4tzz0Z/OnC0lTgu8D/AB4DHgH2ABYBD0TEx2u8ZwGwAGD27NmHrly5Mr+Azcy6gKQlEdE/tLyw0UQRsQG4GTguIlZH8gzwNeCwOu9ZFBH9EdHf27tTYjMzsxHKezRRb1YjQFIP8GrgN5JmZmUC3gTck2dcZmZll/doopnA4qzfYAzw7Yi4VtKPJfUCApYCp+Ucl5lZqeU9muhu4OAa5cfkGYeZme3IM5DNzKxcyWBgAObOhTFj0s+BgaIjMjPrDHn3GRRmYAAWLIDNm9PxypXpGGD+/OLiMjPrBKWpGZxzzvZEULF5cyo3Myu70iSDVauaKzczK5PSJIPZs5srNzMrk9Ikg/PPh4kTdyybODGVm5mVXWmSwfz5sGgRTJ+ejvv60rE7j83MSjSaCNIX//jx8Pa3w/XXw0EHFR2RmVlnKE3NoGLq1PTzD38oNg4zs05SumQwZUr66WRgZradk4GZmTkZmJlZiZPBhg3FxmFm1klKlwx6etKIItcMzMy2K10ykFLtwMnAzGy70iUDcDIwMxuqtMnAfQZmZtuVNhm4ZmBmtl0pk8HUqU4GZmbVSpkMXDMwM9tRrslA0gRJv5B0l6R7JZ2blf+ZpNslLZf0LUl7tDMOJwMzsx3lXTN4BjgmIv4SmAccJ+lw4NPARRHxfGA9cGo7g5gyBTZuhK1b23kVM7PRI9dkEMkT2eH47BHAMcBVWfli4E3tjKMyC3nTpnZexcxs9Mi9z0DSWElLgbXADcADwIaI2JKd8iDQ184YvIy1mdmOck8GEbE1IuYBBwCHAS+qdVqt90paIGlQ0uC6detGHIMXqzMz21Fho4kiYgNwM3A4MFVSZde1A4CH67xnUUT0R0R/b2/viK/tZGBmtqO8RxP1SpqaPe8BXg0sA34CvDU77STge+2MwyuXmpntKO89kGcCiyWNJSWib0fEtZJ+DfybpPOAO4GvtjMI1wzMzHaUazKIiLuBg2uU/5bUf5ALdyCbme2otDOQwcnAzKyilMlgzz3Tw8nAzCwpZTIAL2NtZlat1MnANQMzs6S0ycDLWJuZbVfaZOCagZnZdk4GZmZW7mTgDmQzs6TUycA1AzOzpLTJYOpUePJJ2LJl1+eamXW70iaDyizkjRuLjcPMrBOUPhm4qcjMzMnAnchmZjgZuGZgZkaJk4GXsTYz2660ycA1AzOz7ZwMnAzMzJwM3IFsZlbiZDB+PPT0uGZgZgYlTgbgZazNzCpKnQy8PpGZWZJrMpB0oKSfSFom6V5Jp2flCyU9JGlp9nh9HvE4GZiZJeNyvt4W4MyIuEPSZGCJpBuy1y6KiM/kGYyXsTYzS3KtGUTE6oi4I3u+CVgG9OUZQzXXDMzMksL6DCTNBQ4Gbs+KPiDpbkmXSZqWRwzuQDYzSwpJ
BpImAd8BzoiIjcC/As8F5gGrgc/Wed8CSYOSBtetW7fbcbhmYGaW5J4MJI0nJYKBiLgaICLWRMTWiNgGfBk4rNZ7I2JRRPRHRH9vb+9uxzJlCjz1FDz77G5/lJnZqJb3aCIBXwWWRcSFVeUzq057M3BPHvF4SQozsyTv0USvAE4EfiVpaVb2EeAdkuYBAawA3pNHMNXJoAUVDTOzUSvXZBARtwKq8dJ1ecZR4WWszcyS0s9ABicDMzMnA5wMzMycDPAsZDMzJwNcMzAzK3Uy2Hvv9NPJwMzKruFkIOlgSVdLelTSFkmHZOUXSDqufSG2z7hxMGmSk4GZWUPJQNIrgZ8Dfw58Y8j7tgGntT60fHhJCjOzxmsGnwJ+CLwE+Kchr90BHNLKoPLkZazNzBqfdHYI8JaICEkx5LVHgVE7f9c1AzOzxmsGTwMT67w2Exi1X6dextrMrPFkcCtwhqSxVWWVGsKpwI9bGlWOXDMwM2u8meijwM+Au4CrSIngJEkXAocCL2tPeO3nZGBm1mDNICLuAo4E1gDnkBab+0D28qsi4r72hNd+7kA2M2ti1dJs7+JjJU0ApgMbImJz2yLLyZQpaXObp5+GCROKjsbMrBhNz0COiKcj4uFuSATgZazNzKDBmoGk/72LUyIiPtGCeHJXvT7RfvsVG4uZWVEabSZaOMxrlVFFoz4ZmJmVVaMdyGOGPoAZwMmk/Yqf18YY28rLWJuZ7ca2lxGxHrhC0gzgEuD1LYsqR64ZmJm1ZgnryrDTUckdyGZmrUkGbwDWteBzCuGagZlZ46OJLqtRvAdwEPAXwMca/JwDgSuA/UlLXy+KiIslTQe+BcwFVgB/mzVDtd3kySA5GZhZuTXaZ3AM20cNVTwNrAQ+Byxu8HO2AGdGxB2SJgNLJN1A6oi+KSI+Jeks4Czgww1+5m4ZMyYlBHcgm1mZNZQMImJuKy4WEauB1dnzTZKWAX3A8cBR2WmLgZvJKRmA1ycyMytsD2RJc4GDgduB/bJEUUkY+9Z5zwJJg5IG161rXTeFl7E2s7KrWzOQ1NQIoYj4aaPnSpoEfAc4IyI2Smr0GouARQD9/f1Dm61GzDUDMyu74ZqJbmbnfoJalJ03dlcnAkgaT0oEAxFxdVa8RtLMiFgtaSawtpHPapUpU2D16jyvaGbWWYZLBke3+mJKVYCvAssi4sKql64BTiLttXwS8L1WX3s4U6bAsmV5XtHMrLPUTQYRcUsbrvcK4ETgV5KWZmUfISWBb0s6FVgFvK0N167LzURmVnYjXo5iJCLiVlKzUi3H5hlLtUoHckSac2BmVjYNJwNJB5H2O34hMHQbmIiIwr7Md9eUKbBlCzz1FEycWHQ0Zmb5a3QG8n8DbiHNDn4+cDcwDZgNPAjc36b4clG9JIWTgZmVUaPzDC4ArgZeQmrmOTWbiPZq0iii89oSXU68jLWZlV2jyeClwJVsH2o6FiAifkxKBJ9sfWj58WJ1ZlZ2jSaD8cCTEbENeByYWfXafaQF60YtL2NtZmXXaDJ4gLSGEKT+glMkjZE0BvgH4JF2BJcX1wzMrOwaHU30H6SF5L5B6j/4PrAR2ApMAj7YjuDy4mRgZmXX6KqlC6ue3yjpcOAEYCJwfUT8qD3h5cMdyGZWdiOadBYRdwJ3tjiWwkyalPY1cM3AzMqqoT4DSVdLelO2yFzXkWDvvZ0MzKy8Gu1A/nPSPIPVki7Jmom6ivc0MLMyaygZRMSLgZeR5hq8BfiZpOWSPirpOe0MMC9erM7Myqzhnc4iYklEnAEcAPwN8EvS1pTLJf1nm+LLzZQp7kA2s/JqetvLiNgaEddFxDtJtYSHgZe3PLKcuWZgZmXW9GgiSc8F/h6YDzyXtMH9Z1scV+6cDMyszBpdtXQa8HbSxjSHA5uB7wLvB26MiJbtR1wUdyCbWZk12kz0CPBF4EngZGC/iHhXRNzQDYlgYACuuALWr4c5c9KxmVmZNNpM9C/AlRHRddvGDwzA
ggWweXM6XrUqHQPMn19cXGZmedJo/cO+v78/BgcHd/tz5s6FlSt3Lp8zB1as2O2PNzPrKJKWRET/0PKmRxN1m1Wrmis3M+tGuSYDSZdJWivpnqqyhZIekrQ0e7w+z5hmz26u3MysG+VdM7gcOK5G+UURMS97XJdnQOefv/O+xxMnpnIzs7LINRlExE9JO6V1jPnzYdGi7TWBKVPSsTuPzaxMOqXP4AOS7s6akablffH581Mn8syZcMIJTgRmVj6NLmF9vKR/qDqeI+nnkjZJukrSpN2I4V9JM5nnsYvZzJIWSBqUNLhu3brduGRtfX3w0EMt/1gzs47XaM3gX4DequMLSQvWLQKOBBaONICIWJOtd7QN+DJw2DDnLoqI/ojo7+3trXfaiPX1wcMPt/xjzcw6XqPJ4LnA3QCSeoDXA/8UEWcCHwHePNIAJM2sOnwzcE+9c9tt1izXDMysnBqdgTwBeCp7/vLsfZV9j+8DZjXyIZK+CRwF7CPpQeBjwFGS5gEBrADe02BMLdfXB48/Dk89BT09RUVhZpa/RpPBCuCVwC3A8cCSiKgs67Yv0NASbxHxjhrFX20whrablaW01avhOV2xZY+ZWWMabSa6FFgoaRB4Hzt+gR8B/LrVgRWhry/9dFORmZVNQzWDiLhY0qOk5as/HxFXVL08GfhaO4LLm5OBmZVVw5vbRMQAsNPizhFRWBt/q1WaiTyiyMzKptF5Bi+QdFjVcY+kT0r6D0kfaF94+Zo6NXUcu2ZgZmXTaJ/BF4G3Vh2fD5xJGkV0kaT3tzqwIkipduCagZmVTaPJ4KXAzwAkjQHeBXw4Ig4FzgMWtCe8/HkWspmVUaPJYCrwWPb8YGAacFV2fDPQNQMxnQzMrIwaTQZrgOdlz18DPBARv8+OJwFbWh1YUSrNRKN0AzgzsxFpdDTRNcAnJR0EnEyad1DxF8BvWxxXYfr64OmnYf16mD696GjMzPLRaDI4i7QkxWtJieGCqtfeyPalKUa96uGlTgZmVhaNTjp7Enh3ndde3tKIClY98eygg4qNxcwsLw1POgOQNJ20/MR0UofybRHRUTuX7S7PQjazMmo4GUg6jzS3YM+q4mckfSYiPtryyAoyM1tQ23MNzKxMGp2BfAZp34IrgaOBF2U/rwQ+IumDbYswZxMmwIwZrhmYWbk0WjM4Dbg4Ij5UVXYfcIukJ0grmX6+1cEVxZvcmFnZNDrPYC7w/TqvfT97vWt4+0szK5tGk8FjQL2xNS9h++zkruBZyGZWNo0mg+8Cn5B0oqTxAJLGSXoH8HHgO+0KsAizZsGaNbCla+ZVm5kNr9FkcDawFFgMbJa0hrQn8gBwF6lzuWv09aXlKB55pOhIzMzy0eiks02SjgT+GvjvpHkGj5P2RP5BRHet5FM91+CAA4qNxcwsD83sdBbAtdmjq3nHMzMrm0abiVpC0mWS1kq6p6psuqQbJC3Pfk7LM6ZaPAvZzMqmbjKQtE3S1gYfjXa1Xg4cN6TsLOCmiHg+cFN2XKjeXhg3zjUDMyuP4ZqJPg60tC8gIn4qae6Q4uOBo7Lni0mb5Xy4lddt1pgxaVkK1wzMrCzqJoOIWJhTDPtFxOrsmqsl7ZvTdYfluQZmVia59hnsLkkLJA1KGly3bl1br1XZ8czMrAw6IRmskTQTIPu5tt6JEbEoIvojor+3t7etQblmYGZl0gnJ4BrgpOz5ScD3CozlT2bNgo0b4Yknio7EzKz98h5a+k3g58ALJT0o6VTgU8BfSVoO/FV2XLjK8FI3FZlZGTS109nuioh31Hnp2DzjaET1XIMXvKDYWMzM2q0Tmok6kmchm1mZOBnU4VnIZlYmTgZ1TJ4MkyY5GZhZOTgZDMM7nplZWTgZDMNzDcysLJwMhjFrlpOBmZWDk8Ew+vpg9WrYtq3oSMzM2svJYBizZsEf/wiPPlp0JGZm7eVkMAzPQjazsnAyGIbnGphZWTgZDKMyC9nJwMy6nZPBMGbOBMnN
RGbW/ZwMhjF+POy7r2sGZtb9nAx2wTuemVkZOBnsgmchm1kZOBnsgmchm1kZOBnsQl9fmnT2zDNFR2Jm1j5OBrvw+9+nnz09MHcuDAwUGo6ZWVs4GQxjYAC+/vX0PAJWroQFC5wQzKz7OBkM45xzdm4e2rw5lZuZdRMng2GsWtVcuZnZaDWu6AAqJK0ANgFbgS0R0V9sRDB7dmoaqlVuZtZNOq1mcHREzOuERABw/vkwceKOZRMnpnIzs27Sacmgo8yfD4sWwZw5aY0igHe+M5U3a2AgjUYaM8ajksys83RSMgjgR5KWSFpQdDAV8+fDihWwdSu87GVw443w7LPNfcbAQBqFtHKlRyWZWWfqpGTwiog4BHgd8H5JRw49QdICSYOSBtetW5drcBIsXJgSw+LFzb33nHPSKKRqHpVkZp1EEVF0DDuRtBB4IiI+U++c/v7+GBwczC8o0l/1RxyR9kVevhz22KOx940Zk947lOT9lc0sX5KW1OqX7YiagaS9JE2uPAdeA9xTbFQ7q9QOVq2Cyy5r7D3Ll6elsGvxqCQz6xQdkQyA/YBbJd0F/AL4fkRcX3BMNb32tal2cP75tdcrqu4o3mcfOOggGDcO9txz53NPOKHt4ZqZNaQjm4kaUUQzUcUNN8BrXgPTpsGGDekv/Mpw0wULduwfGDMGLroIZsxIfQSrVsEBB6QEsWYN3HIL9HfEQFozK4N6zUQdM+lsNFm7Nn3Jr1+fjleuhFNOSc1IQ2sL27bBhRemjufqIamPPJJqGMceC5Mnpw10KkllJENXzcx2R6c0E40q55yzc8fvs8/WX+a61vIV++8P73sfbNyY9kvwkFMzK5KTwQg0uzZRvY7iSy7Zuax6yKknqplZXpwMRqDel/uMGc0tX1EvqaxcCZ//vCeqmVl+nAxGoN6aRRdfvOPyFXPmpON6fQDDDS09/XRPVDOz/DgZjMDQNYuqv/Qry1ds27Zzp/FQ9ZLKhRfWf4+XzzazdnAyGKFmvvSH+4xaSeVDH0rPa/FENTNrByeDgtVLKrVqDT09Xj7bzNrDyaBD1Vo+++ijW7t8tkcrmVmFk0EHq641nHYaXH893HZb/fNrfbnXWz773e9Oj2ZHKzmBmHUnL0cxSmzcmNY5mjQJ7rxz57WOKl/61SOQenrSshebNjV+nXpbfda7xsSJw4+YMrPO0tGrltqu7b13+tJdtgw+8YmdX6+1Z8JTTzWXCCCNVrrmmto1gLPPrj/c1U1RZqNcRIzKx6GHHhpldPLJEVLEzJnp55w5EV/4QkRq7Gn8MXZs7fJx49JPacfyMWOG/7yenh2PJ06MeO9708+h5Vdeme7lyitT/JX72FW5me0+YDBqfKcW/qU+0kdZk8Gllzb3pT9jRu0v5Hpf1Fdckd5T67OGJoiRPqZPj7j44ubickIwa416ycDNRKPMBRfULp86tblZ0V/6Uu3yE0+Exx+vfY2I2tdo1uOP159hfemlzTdFmVkL1MoQo+FR1ppBvb/OpdY1r8yZU/salc8ceo1659dripo5c2Q1ivHja9cYWtms5CYq63a4mag7DPdF3SpXXtlcU02984dr8mk2gdR77LVXxIQJzSeJWq8Nd9959G84EVkenAy6RLNf1LtznWa+mJr9smw2gTRbi5g0KWLPPet/sQ/9zD333LkTvPp9tT5ruGTX7YnIiWv0cjLoIt3yD7GZL7l6NYlmHz09KVG04rPqPaZOHf6LfehrEyZETJlS+7OmTx9ZR3sz/22bTUS7+oOk3X9ItPJ3rZVGy79LJwMb1ep9AdUb+dSJj3HjWjciq97n7LdfxMKFtZvOaiWQnp76iahWUuvpSeW1zp89u3WJZbTVuoqu2TXDycBGvWb+oddLEnPm1K9l1BuGW++zmu3fKMNjjz1ql0+ZsnMz3IQJEXvvXfv8evNa6tWUhvtdmDatuc8aSSI68MD2X6Pev4FmdXwyAI4D7gPuB87a1flOBlYxkr8G2/nXa5GJ
aN99W1f7aPYxeXIx14X0hT99enuvMWHCziPaxo6N2H//5j+r3v+jvfbauX+qpyfN/2lVf2FHJwNgLPAA8BxgD+Au4MXDvcfJwHallX9htbINvt3NKM2O1Go2EdU7v5V9O6Op1jVcU1srH/USSLMjCTs9GRwB/LDq+Gzg7OHe42RgnayoRFQpb0XTxEjbwVuRWJqtdR14YERfX3PXaFXzX2WOT1FNjFJzv5udngzeCnyl6vhE4IvDvcfJwKy+VnaMNttG3c5O3E5t/ms23lY2MXZbzeBtNZLBF2qctwAYBAZnz57d3H8BMytUq0bPdGLzXx7XKEufgZuJzKzj5TEctKjRRB2xuY2kccB/AccCDwG/BN4ZEffWe0/ZNrcxM2uFepvbjCsimKEiYoukDwA/JI0sumy4RGBmZq3VEckAICKuA64rOg4zszLyfgZmZuZkYGZmTgZmZgadMZpoJCStA1bu4rR9gEdzCKfT+L7LxfddPrtz73Miondo4ahNBo2QNFhrCFW3832Xi++7fNpx724mMjMzJwMzM+v+ZLCo6AAK4vsuF993+bT83ru6z8DMzBrT7TUDMzNrQFcmA0nHSbpP0v2Szio6nnaSdJmktZLuqSqbLukGScuzn9OKjLEdJB0o6SeSlkm6V9LpWXlX37ukCZJ+Iemu7L7Pzcr/TNLt2X1/S9IeRcfaDpLGSrpT0rXZcdfft6QVkn4laamkways5b/nXZcMJI0FLgFeB7wYeIekFxcbVVtdTto/utpZwE0R8Xzgpuy422wBzoyIFwGHA+/P/j93+70/AxwTEX8JzAOOk3Q48Gngouy+1wOnFhhjO50OLKs6Lst9Hx0R86qGk7b897zrkgFwGHB/RPw2Ip4F/g04vuCY2iYifgo8PqT4eGBx9nwx8KZcg8pBRKyOiDuy55tIXxB9dPm9Z0vSP5Edjs8eARwDXJWVd919A0g6APhr4CvZsSjBfdfR8t/zbkwGfcDvq44fzMrKZL+IWA3pSxPYt+B42krSXOBg4HZKcO9ZU8lSYC1wA/AAsCEitmSndOvv/OeA/wVsy45nUI77DuBHkpZIWpCVtfz3vGOWsG4h1SjzkKkuJWkS8B3gjIjYmP5Y7G4RsRWYJ2kq8F3gRbVOyzeq9pL0BmBtRCyRdFSluMapXXXfmVdExMOS9gVukPSbdlykG2sGDwIHVh0fADxcUCxFWSNpJkD2c23B8bSFpPGkRDAQEVdnxaW4d4CI2ADcTOozmZrtGAjd+Tv/CuCNklaQmn6PIdUUuv2+iYiHs59rScn/MNrwe96NyeCXwPOzUQZ7AH8HXFNwTHm7Bjgpe34S8L0CY2mLrL34q8CyiLiw6qWuvndJvVmNAEk9wKtJ/SU/Ad6andZ19x0RZ0fEARExl/Rv+scRMZ8uv29Je0maXHkOvAa4hzb8nnflpDNJryf91VDZQvP8gkNqG0nfBI4irWK4BvgY8O/At4HZwCrgbRExtJN5VJP0SuA/gV+xvQ35I6R+g669d0kvJXUYjiX9MfftiPi4pOeQ/mKeDtwJ/H1EPFNcpO2TNRP9z4h4Q7ffd3Z/380OxwHfiIjzJc2gxb/nXZkMzMysOd3YTGRmZk1yMjAzMycDMzNzMjAzM5wMzMwMJwPrQpIWSors+dTs+JAC45mXxTC9xmshaWEBYZntwMnAutFXgCOy51NJcy8KSwak1UU/RhoLP9QRZAuvmRWpG9cmspKLiAdJy5K0RTb7eXy2Ku5uiYjbWhCS2W5zzcC6TqWZKFvN9HdZ8ZezspB0ctW5b5F0m6TNkjZI+r+SZg/5vBWSrpR0SrZI2LOkpZSRdK6kOyT9QdKjkn6c7S9Qee/JwNeyw+VVMczNXt+pmUhpc6afS3oq+9x/l/TCIefcLOlWSa/Orr9Z0j2SyrKEs7WYk4F1s9XAW7LnnyQ1yRwBfB9A0mmkhe5+TVrf5j3AQcAtlfVgqhwN/BNwLmkzobuz8j7gItJ68ieTFgz7
abZsBNm1zsuev60qhtW1ApZ0XPaeJ4C3A+/NYrpV0tDlmZ8LXAxcmN3nauAqSc8b9r+KWQ1uJrKuFRHPSLozO/xtdZNMtvT1p4GvRcQpVeW3A/9F2jHrc1UfNw04NCIeGXKNf6x671jgeuDe7P2nR8Q6SQ9kpyyNiPt3EfZ5wG+B11XW6Zf08yymM0kJqWIf4MiIWJ6ddwcpIfwtcMEurmO2A9cMrKyOAPYGBiSNqzxIfQ2/AY4ccv5tQxMBQNZM8xNJj5G24vwj8ALghUPP3ZVsVcpDgG9VbdhCRPwO+BnwqiFvWV5JBNl5a0k1k9mYNck1Ayurys5QN9Z5ff2Q452adbLhqtcBPyTVBFYDW0mjgyaMIKZppA1bajUhPQLMGVJWa5XKZ0Z4bSs5JwMrq8eynyeTmnWG2jTkuNbyvieQagNviYg/VgolTQM2jCCm9dl19q/x2v5sj9ms5ZwMrNtV1rbvGVL+/0hf+M+LiMWMzERSTeBPiULSMaRmmt9VnVcvhh1ExJOSlgBvk7Qw294SSXOAlwNfGGGcZrvkZGDdbg3pL+q/k3Q38CTwu4h4TNI/A5dI6gV+APyBNDroVcDNEfGNXXz29cAZwOWSvkbqK/go8NCQ836d/Xy/pMWkfoW768xT+ChpNNG1kr4ETCKNYPoD8Nkm7tusKe5Atq4WEduAfyS1x99I2hb1b7LXLgXeSOrs/TopIZxL+iNpaQOf/UPgg6T9ea8FTgHeBdw/5Ly7gIXZdW/NYphV5zOvJ81hmErayer/kLa1fGVlL1yzdvBOZ2Zm5pqBmZk5GZiZGU4GZmaGk4GZmeFkYGZmOBmYmRlOBmZmhpOBmZnhZGBmZsD/B+JOMITKuC/PAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "## Work from the repository root so the local DeepPurpose package and the 'data/' folder resolve.\n",
    "## If you are using the pip version of DeepPurpose, please comment out the os.chdir guard below.\n",
    "import os\n",
    "if not os.path.isdir('DeepPurpose'):\n",
    "    # guard keeps the cell idempotent: re-running it must not climb one more directory\n",
    "    os.chdir('../')\n",
    "\n",
    "import random\n",
    "import warnings\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "from DeepPurpose import CompoundPred as models\n",
    "from DeepPurpose.utils import data_process, generate_config\n",
    "from tdc import BenchmarkGroup\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "## Seed every RNG the run touches so the split and the training are reproducible.\n",
    "SEED = 1\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "group = BenchmarkGroup(name = 'ADMET_Group', path = 'data/')\n",
    "\n",
    "## 0.1.2 new supported models: \n",
    "## DGL_GCN, DGL_NeuralFP, DGL_GIN_AttrMasking, DGL_GIN_ContextPred, DGL_AttentiveFP\n",
    "drug_encoding = 'DGL_GCN'\n",
    "\n",
    "benchmark = group.get('Caco2_Wang')\n",
    "benchmark_name = benchmark['name']\n",
    "\n",
    "train_df, valid_df = group.get_train_valid_split(benchmark = benchmark_name, split_type = 'default', seed = SEED)\n",
    "test_df = benchmark['test']\n",
    "\n",
    "def encode(split_df):\n",
    "    \"\"\"Encode a TDC split (columns Drug, Y) with `drug_encoding`; the split is used as-is (no re-splitting).\"\"\"\n",
    "    return data_process(X_drug = split_df.Drug.values, y = split_df.Y.values,\n",
    "                        drug_encoding = drug_encoding,\n",
    "                        split_method = 'no_split')\n",
    "\n",
    "train = encode(train_df)\n",
    "val = encode(valid_df)\n",
    "test = encode(test_df)\n",
    "\n",
    "config = generate_config(drug_encoding = drug_encoding, \n",
    "                         cls_hidden_dims = [512], \n",
    "                         train_epoch = 10, \n",
    "                         LR = 0.001, \n",
    "                         batch_size = 128,\n",
    "                        )\n",
    "\n",
    "model = models.model_initialize(**config)\n",
    "model.train(train, val, test, verbose = True)\n",
    "y_pred = model.predict(test)\n",
    "\n",
    "## Score the test predictions with TDC's official benchmark evaluator (shown as the cell output).\n",
    "group.evaluate({benchmark_name: y_pred})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:DeepPurpose]",
   "language": "python",
   "name": "conda-env-DeepPurpose-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
